{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "from keras.models import Sequential\n",
    "from keras.layers import Dense\n",
    "from keras.layers import Reshape\n",
    "from keras.layers.core import Activation\n",
    "from keras.layers.normalization import BatchNormalization\n",
    "from keras.layers.convolutional import UpSampling2D\n",
    "from keras.layers.convolutional import Conv2D, MaxPooling2D\n",
    "from keras.layers.core import Flatten\n",
    "from keras.optimizers import SGD\n",
    "from keras.datasets import mnist\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "import argparse\n",
    "import math"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "整个训练过程可以说判别器 D 和生成器 G 对价值函数 V(G,D) 进行了极小极大化博弈：\n",
    "![gan_tf_keras1.png](./Res/gan_tf_keras1.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 最优生成器\n",
    "\n",
    "因为最优的判别器D(x)=P_data/(P_data+P_G)，我们将其代入V（G,D）可得：\n",
    "![gan_tf_keras2.png](./Res/gan_tf_keras2.png)\n",
    "该积分进行变换得：\n",
    "![gan_tf_keras3.png](./Res/gan_tf_keras3.png)\n",
    "假设存在两个分布 P 和 Q，且这两个分布的平均分布 M=(P+Q)/2，那么这两个分布之间的 JS 散度为 P 与 M 之间的 KL 散度加上 Q 与 M 之间的 KL 散度再除以 2；因此可化为：\n",
    "![gan_tf_keras4.png](./Res/gan_tf_keras4.png)\n",
    "\n",
    "JS 散度的取值为 0 到 log2。若两个分布完全没有交集，那么 JS 散度取最大值 log2；若两个分布完全一样，那么 JS 散度取最小值 0。当 P_G=P_data 时，JSD(P_data||P_G) 为 0。综上所述，生成分布当且仅当等于真实数据分布式时，我们可以取得最优生成器。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def generator_model():\n",
    "    #下面搭建生成器的架构，首先导入序贯模型（sequential），即多个网络层的线性堆叠\n",
    "    model = Sequential()\n",
    "    #添加一个全连接层，输入为100维向量，输出为1024维\n",
    "    model.add(Dense(input_dim=100, output_dim=1024))\n",
    "    #添加一个激活函数tanh\n",
    "    model.add(Activation('tanh'))\n",
    "    #添加一个全连接层，输出为128×7×7维度\n",
    "    model.add(Dense(128*7*7))\n",
    "    #添加一个批量归一化层，该层在每个batch上将前一层的激活值重新规范化，即使得其输出数据的均值接近0，其标准差接近1\n",
    "    model.add(BatchNormalization())\n",
    "    model.add(Activation('tanh'))\n",
    "    #Reshape层用来将输入shape转换为特定的shape，将含有128*7*7个元素的向量转化为7×7×128张量\n",
    "    model.add(Reshape((7, 7, 128), input_shape=(128*7*7,)))\n",
    "    #2维上采样层，即将数据的行和列分别重复2次\n",
    "    model.add(UpSampling2D(size=(2, 2)))\n",
    "    #添加一个2维卷积层，卷积核大小为5×5，激活函数为tanh，共64个卷积核，并采用padding以保持图像尺寸不变\n",
    "    model.add(Conv2D(64, (5, 5), padding='same'))\n",
    "    model.add(Activation('tanh'))\n",
    "    model.add(UpSampling2D(size=(2, 2)))\n",
    "    #卷积核设为1即输出图像的维度\n",
    "    model.add(Conv2D(1, (5, 5), padding='same'))\n",
    "    model.add(Activation('tanh'))\n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 最优判别器\n",
    "\n",
    "原论文中价值函数可写为在 x 上的积分，即将数学期望展开为积分形式：\n",
    "![gan_tf_keras5.png](./Res/gan_tf_keras5.png)\n",
    "其实求积分的最大值可以转化为求被积函数的最大值。而求被积函数的最大值是为了求得最优判别器 D，因此不涉及判别器的项都可以看作为常数项。\n",
    "若令判别器 D(x) 等于 y，那么被积函数可以写为：\n",
    "![gan_tf_keras6.png](./Res/gan_tf_keras6.png)\n",
    "为了找到最优的极值点，如果 a+b≠0，我们可以用以下一阶导求解：\n",
    "![gan_tf_keras6.png](./Res/gan_tf_keras7.png)\n",
    "因此，最优判别器D(x)=P_data/(P_data+P_G)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def discriminator_model():\n",
    "    #下面搭建判别器架构，同样采用序贯模型\n",
    "    model = Sequential()\n",
    "    \n",
    "    #添加2维卷积层，卷积核大小为5×5，激活函数为tanh，输入shape在‘channels_first’模式下为（samples,channels，rows，cols）\n",
    "    #在‘channels_last’模式下为（samples,rows,cols,channels），输出为64维\n",
    "    model.add(\n",
    "            Conv2D(64, (5, 5),\n",
    "            padding='same',\n",
    "            input_shape=(28, 28, 1))\n",
    "            )\n",
    "    model.add(Activation('tanh'))\n",
    "    #为空域信号施加最大值池化，pool_size取（2，2）代表使图片在两个维度上均变为原长的一半\n",
    "    model.add(MaxPooling2D(pool_size=(2, 2)))\n",
    "    model.add(Conv2D(128, (5, 5)))\n",
    "    model.add(Activation('tanh'))\n",
    "    model.add(MaxPooling2D(pool_size=(2, 2)))\n",
    "    #Flatten层把多维输入一维化，常用在从卷积层到全连接层的过渡\n",
    "    model.add(Flatten())\n",
    "    model.add(Dense(1024))\n",
    "    model.add(Activation('tanh'))\n",
    "    #一个结点进行二值分类，并采用sigmoid函数的输出作为概念\n",
    "    model.add(Dense(1))\n",
    "    model.add(Activation('sigmoid'))\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def generator_containing_discriminator(g, d):\n",
    "    #将前面定义的生成器架构和判别器架构组拼接成一个大的神经网络，用于判别生成的图片\n",
    "    model = Sequential()\n",
    "    #先添加生成器架构，再令d不可训练，即固定d\n",
    "    #因此在给定d的情况下训练生成器，即通过将生成的结果投入到判别器进行辨别而优化生成器\n",
    "    model.add(g)\n",
    "    d.trainable = False\n",
    "    model.add(d)\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def combine_images(generated_images):\n",
    "    #生成图片拼接\n",
    "    num = generated_images.shape[0]\n",
    "    width = int(math.sqrt(num))\n",
    "    height = int(math.ceil(float(num)/width))\n",
    "    shape = generated_images.shape[1:3]\n",
    "    image = np.zeros((height*shape[0], width*shape[1]),\n",
    "                     dtype=generated_images.dtype)\n",
    "    for index, img in enumerate(generated_images):\n",
    "        i = int(index/width)\n",
    "        j = index % width\n",
    "        image[i*shape[0]:(i+1)*shape[0], j*shape[1]:(j+1)*shape[1]] = \\\n",
    "            img[:, :, 0]\n",
    "    return image"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 对于每一次迭代：\n",
    "\n",
    "-  从真实数据分布 P_data 抽取 m 个样本\n",
    "-  从先验分布 P_prior(z) 抽取 m 个噪声样本\n",
    "-  将噪声样本投入 G 而生成数据，即x^tilde = G(Z^i)；通过最大化 V 的近似而更新判别器参数θ_d\n",
    "\n",
    "以上是学习判别器 D 的过程。因为学习 D 的过程是计算 JS 散度的过程，并且我们希望能最大化价值函数，所以该步骤会重复 k 次。\n",
    "\n",
    "-  从先验分布 P_prior(z) 中抽取另外 m 个噪声样本 {z^1,...,z^m}\n",
    "-  通过极小化 V^tilde 而更新生成器参数θ_g"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def train(BATCH_SIZE):\n",
    "    \n",
    "    # 国内好像不能直接导入数据集，我们试了几次都不行，后来将数据集下载到本地'~/.keras/datasets/'，也就是当前目录（我的是用户文件夹下）下的.keras文件夹中。\n",
    "    #下载的地址为：https://s3.amazonaws.com/img-datasets/mnist.npz\n",
    "    (X_train, y_train), (X_test, y_test) = mnist.load_data()\n",
    "    #iamge_data_format选择\"channels_last\"或\"channels_first\"，该选项指定了Keras将要使用的维度顺序。\n",
    "    #\"channels_first\"假定2D数据的维度顺序为(channels, rows, cols)，3D数据的维度顺序为(channels, conv_dim1, conv_dim2, conv_dim3)\n",
    "    \n",
    "    #转换字段类型，并将数据导入变量中\n",
    "    X_train = (X_train.astype(np.float32) - 127.5)/127.5\n",
    "    X_train = X_train[:, :, :, None]\n",
    "    X_test = X_test[:, :, :, None]\n",
    "    # X_train = X_train.reshape((X_train.shape, 1) + X_train.shape[1:])\n",
    "    \n",
    "    #将定义好的模型架构赋值给特定的变量\n",
    "    d = discriminator_model()\n",
    "    g = generator_model()\n",
    "    d_on_g = generator_containing_discriminator(g, d)\n",
    "    \n",
    "    #定义生成器模型判别器模型更新所使用的优化算法及超参数\n",
    "    d_optim = SGD(lr=0.001, momentum=0.9, nesterov=True)\n",
    "    g_optim = SGD(lr=0.001, momentum=0.9, nesterov=True)\n",
    "    \n",
    "    #编译三个神经网络并设置损失函数和优化算法，其中损失函数都是用的是二元分类交叉熵函数。编译是用来配置模型学习过程的\n",
    "    g.compile(loss='binary_crossentropy', optimizer=\"SGD\")\n",
    "    d_on_g.compile(loss='binary_crossentropy', optimizer=g_optim)\n",
    "    \n",
    "    #前一个架构在固定判别器的情况下训练了生成器，所以在训练判别器之前先要设定其为可训练。\n",
    "    d.trainable = True\n",
    "    d.compile(loss='binary_crossentropy', optimizer=d_optim)\n",
    "    \n",
    "    #下面在满足epoch条件下进行训练\n",
    "    for epoch in range(30):\n",
    "        print(\"Epoch is\", epoch)\n",
    "        \n",
    "        #计算一个epoch所需要的迭代数量，即训练样本数除批量大小数的值取整；其中shape[0]就是读取矩阵第一维度的长度\n",
    "        print(\"Number of batches\", int(X_train.shape[0]/BATCH_SIZE))\n",
    "        \n",
    "        #在一个epoch内进行迭代训练\n",
    "        for index in range(int(X_train.shape[0]/BATCH_SIZE)):\n",
    "            \n",
    "            #随机生成的噪声服从均匀分布，且采样下界为-1、采样上界为1，输出BATCH_SIZE×100个样本；即抽取一个批量的随机样本\n",
    "            noise = np.random.uniform(-1, 1, size=(BATCH_SIZE, 100))\n",
    "            \n",
    "            #抽取一个批量的真实图片\n",
    "            image_batch = X_train[index*BATCH_SIZE:(index+1)*BATCH_SIZE]\n",
    "            \n",
    "            #生成的图片使用生成器对随机噪声进行推断；verbose为日志显示，0为不在标准输出流输出日志信息，1为输出进度条记录\n",
    "            generated_images = g.predict(noise, verbose=0)\n",
    "            \n",
    "            #每经过100次迭代输出一张生成的图片\n",
    "            if index % 100 == 0:\n",
    "                image = combine_images(generated_images)\n",
    "                image = image*127.5+127.5\n",
    "                Image.fromarray(image.astype(np.uint8)).save(\n",
    "                    \"./GAN/\"+str(epoch)+\"_\"+str(index)+\".png\")\n",
    "            \n",
    "            #将真实的图片和生成的图片以多维数组的形式拼接在一起，真实图片在上，生成图片在下\n",
    "            X = np.concatenate((image_batch, generated_images))\n",
    "            \n",
    "            #生成图片真假标签，即一个包含两倍批量大小的列表；前一个批量大小都是1，代表真实图片，后一个批量大小都是0，代表伪造图片\n",
    "            y = [1] * BATCH_SIZE + [0] * BATCH_SIZE\n",
    "            \n",
    "            #判别器的损失；在一个batch的数据上进行一次参数更新\n",
    "            d_loss = d.train_on_batch(X, y)\n",
    "            print(\"batch %d d_loss : %f\" % (index, d_loss))\n",
    "            \n",
    "            #随机生成的噪声服从均匀分布\n",
    "            noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))\n",
    "            \n",
    "            #固定判别器\n",
    "            d.trainable = False\n",
    "            \n",
    "            #计算生成器损失；在一个batch的数据上进行一次参数更新\n",
    "            g_loss = d_on_g.train_on_batch(noise, [1] * BATCH_SIZE)\n",
    "            \n",
    "            #令判别器可训练\n",
    "            d.trainable = True\n",
    "            print(\"batch %d g_loss : %f\" % (index, g_loss))\n",
    "            \n",
    "            #每100次迭代保存一次生成器和判别器的权重\n",
    "            if index % 100 == 9:\n",
    "                g.save_weights('generator', True)\n",
    "                d.save_weights('discriminator', True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def generate(BATCH_SIZE, nice= False ):\n",
    "    #训练完模型后，可以运行该函数生成图片\n",
    "    g = generator_model()\n",
    "    g.compile(loss='binary_crossentropy', optimizer=\"SGD\")\n",
    "    g.load_weights('generator')\n",
    "    if nice:\n",
    "        d = discriminator_model()\n",
    "        d.compile(loss='binary_crossentropy', optimizer=\"SGD\")\n",
    "        d.load_weights('discriminator')\n",
    "        noise = np.random.uniform(-1, 1, (BATCH_SIZE*20, 100))\n",
    "        generated_images = g.predict(noise, verbose=1)\n",
    "        d_pret = d.predict(generated_images, verbose=1)\n",
    "        index = np.arange(0, BATCH_SIZE*20)\n",
    "        index.resize((BATCH_SIZE*20, 1))\n",
    "        pre_with_index = list(np.append(d_pret, index, axis=1))\n",
    "        pre_with_index.sort(key=lambda x: x[0], reverse=True)\n",
    "        nice_images = np.zeros((BATCH_SIZE,) + generated_images.shape[1:3], dtype=np.float32)\n",
    "        nice_images = nice_images[:, :, :, None]\n",
    "        for i in range(BATCH_SIZE):\n",
    "            idx = int(pre_with_index[i][1])\n",
    "            nice_images[i, :, :, 0] = generated_images[idx, :, :, 0]\n",
    "        image = combine_images(nice_images)\n",
    "    else:\n",
    "        noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))\n",
    "        generated_images = g.predict(noise, verbose=0)\n",
    "        image = combine_images(generated_images)\n",
    "    image = image*127.5+127.5\n",
    "    Image.fromarray(image.astype(np.uint8)).save(\n",
    "        \"./GAN/generated_image.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\Horatio\\AppData\\Local\\conda\\conda\\envs\\tensorflow\\lib\\site-packages\\ipykernel_launcher.py:3: UserWarning: Update your `Dense` call to the Keras 2 API: `Dense(units=1024, input_dim=100)`\n",
      "  This is separate from the ipykernel package so we can avoid doing imports until\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "128/132 [============================>.] - ETA: 0s\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b"
     ]
    }
   ],
   "source": [
    "generate(132)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
